Skip to content

Conversation

@lhutton1
Copy link
Contributor

Implement SymbolOpInterface on tosa.variable so that it's declaration is automatically inserted into its parents SymbolTable.

Verifiers for tosa.variable_read/write can now look up the symbol and guarantee it exists, and duplicate names are caught at creation time. Previously this was completed by walking the graph which could be inefficient.

Unfortunately, the Symbol trait expects to find a symbol name via a hard-coded attribute name "sym_name". Therefore, "name" is renamed to"sym_name" and a getName() wrapper is provided for backwards compatibility.

This change also restricts tosa.variable declarations to ops that carry a SymbolTable (e.g. modules), rather than allowing them to be placed inside a func.func.

Note: EXT-VARIABLE is an experimental extension in the TOSA specification, so is not subject to backwards compatibility guarantees.

@llvmbot
Copy link
Member

llvmbot commented Aug 12, 2025

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

Implement SymbolOpInterface on tosa.variable so that it's declaration is automatically inserted into its parents SymbolTable.

Verifiers for tosa.variable_read/write can now look up the symbol and guarantee it exists, and duplicate names are caught at creation time. Previously this was completed by walking the graph which could be inefficient.

Unfortunately, the Symbol trait expects to find a symbol name via a hard-coded attribute name "sym_name". Therefore, "name" is renamed to"sym_name" and a getName() wrapper is provided for backwards compatibility.

This change also restricts tosa.variable declarations to ops that carry a SymbolTable (e.g. modules), rather than allowing them to be placed inside a func.func.

Note: EXT-VARIABLE is an experimental extension in the TOSA specification, so is not subject to backwards compatibility guarantees.


Patch is 26.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153223.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td (+12-4)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+15-52)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+26-29)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+14-8)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+18-12)
  • (modified) mlir/test/Dialect/Tosa/variables.mlir (+77-55)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+33-23)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..41094748fb528 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -201,9 +201,9 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
 // and optional initial value. The builder will extract var_shape and element type
 // attributes from variable type.
 def Tosa_VariableOpBuilder : OpBuilder<
-  (ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
+  (ins "StringRef":$sym_name, "Type":$variable_type, "Attribute":$initial_value),
   [{
-    buildVariableOp($_builder, $_state, name, variable_type, initial_value);
+    buildVariableOp($_builder, $_state, sym_name, variable_type, initial_value);
   }]>;
 
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index c8f2907f8dd1b..270daf8d5668e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -18,6 +18,7 @@
 include "mlir/IR/OpBase.td"
 
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
@@ -82,7 +83,7 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
 //===----------------------------------------------------------------------===//
 // Operator: variable
 //===----------------------------------------------------------------------===//
-def Tosa_VariableOp : Tosa_Op<"variable", []> {
+def Tosa_VariableOp : Tosa_Op<"variable", [Symbol]> {
   let summary = "Defines a variable";
 
   let description = [{
@@ -91,7 +92,10 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
   }];
 
   let arguments = (ins
-    SymbolNameAttr:$name,
+    // Note: "sym_name" is used as opposed to "name" in the specification,
+    // since a Symbol must be named "sym_name" for it to be recognised by
+    // the containing SymbolTable.
+    SymbolNameAttr:$sym_name,
     IndexElementsAttr:$var_shape,
     TypeAttr:$type,
     OptionalAttr<AnyAttr>:$initial_value
@@ -105,14 +109,18 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
   let hasCustomAssemblyFormat = 1;
 
   let assemblyFormat = [{
-    $name
+    $sym_name
     attr-dict
     custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
   }];
 
   let builders = [Tosa_VariableOpBuilder];
 
-  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    ::llvm::StringRef getName() {
+      return getSymName();
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3cafb199d2db3..89f1a03e3abe4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -667,56 +667,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
   return shapeAdaptor.getNumElements() == 1 ? success() : failure();
 }
 
-// Returns the first declaration point prior to this operation or failure if
-// not found.
-static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
-                                                    StringRef symName) {
-  ModuleOp module = op->getParentOfType<ModuleOp>();
-  tosa::VariableOp varOp = nullptr;
-
-  // TODO: Adopt SymbolTable trait to Varible ops.
-  // Currently, the variable's definition point is searched via walk(),
-  // starting from the top-level ModuleOp and stopping at the point of use. Once
-  // TOSA control flow and variable extensions reach the complete state, may
-  // leverage MLIR's Symbol Table functionality to look up symbol and enhance
-  // the search to a TOSA specific graph traversal over the IR structure.
-  module.walk([&](Operation *tempOp) {
-    // Reach this op itself.
-    if (tempOp == op) {
-      return WalkResult::interrupt();
-    }
-
-    if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
-      if (symName == tosaOp.getName()) {
-        varOp = tosaOp;
-        return WalkResult::interrupt();
-      }
-    }
-
-    return WalkResult::advance();
-  });
-
-  if (varOp)
-    return varOp;
-
-  return failure();
-}
-
 template <typename T>
 static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
-  StringRef symName = op.getName();
-  FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
-  if (failed(varOp))
+  Operation *symTableOp =
+      op->template getParentWithTrait<OpTrait::SymbolTable>();
+  if (!symTableOp)
+    // If the operation is not the scope of a symbol table, we cannot
+    // verify it against it's declaration.
+    return success();
+
+  SymbolTable symTable(symTableOp);
+  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
+
+  // Verify prior declaration
+  if (!varOp)
     return op->emitOpError("'")
-           << symName << "' has not been declared by 'tosa.variable'";
+           << op.getName() << "' has not been declared by 'tosa.variable'";
 
   // Verify type and shape
-  auto variableType = getVariableType(varOp.value());
+  auto variableType = getVariableType(varOp);
   if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
                                  "the input tensor")
           .failed())
     return failure();
-
   return success();
 }
 
@@ -1180,7 +1153,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result,
   ArrayRef<int64_t> shape = shapedType.getShape();
   auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
 
-  result.addAttribute("name", nameAttr);
+  result.addAttribute("sym_name", nameAttr);
   result.addAttribute("var_shape", varShapeAttr);
   result.addAttribute("type", elementTypeAttr);
   result.addAttribute("initial_value", initialValue);
@@ -3908,16 +3881,6 @@ LogicalResult tosa::SelectOp::verify() {
   return success();
 }
 
-LogicalResult tosa::VariableOp::verify() {
-  StringRef symName = getName();
-  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
-  if (succeeded(varOp))
-    return emitOpError("illegal to have multiple declaration of '")
-           << symName << "'";
-
-  return success();
-}
-
 LogicalResult tosa::VariableReadOp::verify() {
   if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
           .failed())
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 3bccb32c5b9f4..c024f578f2946 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -564,64 +564,61 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
 
 // -----
 
-func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   tosa.variable @stored_var : tensor<*xi8>
   // expected-error@+1 {{custom op 'tosa.variable' expected ranked type}}
-  return
 }
 
 // -----
 
-func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   // expected-error@+1 {{elements literal type must have static shape}}
   tosa.variable @stored_var = dense<0> : tensor<*xi8>
   // expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
-  return
-}
-
-// -----
-
-func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
-  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
-  tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
-  return
 }
 
 // -----
 
-func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
-  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
-  return
+  func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+    // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
+    %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
+    return
+  }
 }
 
 // -----
 
-func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
-  %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
-  return
+  func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
+    // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
+    %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
+    return
+  }
 }
 
 // -----
 
-func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+module {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
-  tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
-  return
+  func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+    // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
+    tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
+    return
+  }
 }
 
 // -----
 
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
+module {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
-  tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
-  return
+  func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
+    // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
+    tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
+    return
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 3154f541e0519..a7bb0138103ab 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -310,21 +310,27 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> {
 }
 
 // -----
-func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   // expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
-  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
-  return
+
+  func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+    // expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
+    %0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
+    return
+  }
 }
 
 // -----
-func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   // expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
-  tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
-  return
+
+  func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
+    // expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
+    tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
+    return
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0184d2b05f0ee..633705b41f8e7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1088,14 +1088,17 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %
 
 // -----
 
-func.func @test_variable_read_write_tensor_size_invalid() -> () {
+module {
   // expected-error@+1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
   tosa.variable @stored_var : tensor<536870912xf32>
-  // expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.variable_read @stored_var : tensor<536870912xf32>
-  // expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
-  return
+
+  func.func @test_variable_read_write_tensor_size_invalid() -> () {
+    // expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+    %0 = tosa.variable_read @stored_var : tensor<536870912xf32>
+    // expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+    tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
+    return
+  }
 }
 
 // -----
@@ -1156,14 +1159,17 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
 
 // -----
 
-func.func @test_variable_read_write_rank_invalid() -> () {
+module {
   // expected-error@+1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}}
   tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
-  // expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
-  %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
-  // expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
-  tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
-  return
+
+  func.func @test_variable_read_write_rank_invalid() -> () {
+    // expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
+    %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
+    // expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
+    tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
+    return
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 9953eb375d3ac..0c104e8e8d7ea 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -3,76 +3,98 @@
 
 
 // -----
-// CHECK-LABEL:   @test_variable_scalar(
-// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<f32>) {
-func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
-  // CHECK:           tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
+
+module {
+  // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
   tosa.variable @stored_var = dense<3.14> : tensor<f32>
-  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
-  %0 = tosa.variable_read @stored_var : tensor<f32>
-  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
-  tosa.variable_write @stored_var, %1 : tensor<f32>
-  return
+
+  // CHECK-LABEL: @test_variable_scalar(
+  // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+  func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
+    // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
+    %0 = tosa.variable_read @stored_var : tensor<f32>
+    // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+    tosa.variable_write @stored_var, %1 : tensor<f32>
+    return
+  }
 }
 
+
 // -----
-// CHECK-LABEL:   @test_variable_tensor(
-// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
-func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
-  // CHECK:           tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+
+module {
+  // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
-  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
-  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
-  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
-  %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
-  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
-  tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
-  return
+
+  // CHECK-LABEL: @test_variable_tensor(
+  // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+  func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
+    // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+    %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+    // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+    %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+    // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+    tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+    return
+  }
 }
 
 // -----
-// CHECK-LABEL:   @test_variable_scalar_no_initial_value(
-// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<f32>) {
-func.func @test_variable_scalar_no_initial_value(%arg0: tensor<f32>) -> () {
-  // CHECK:           tosa.variable @stored_var : tensor<f32>
+
+module {
+  // CHECK: tosa.variable @stored_var : tensor<f32>
   tosa.variable @stored_var : tensor<f32>
-  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
-  %0 = tosa.variable_read @stored_var : tensor<f32>
-  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
-  tosa.variable_write @stored_var, %1 : tensor<f32>
-  return
+
+  // CHECK-LABEL: @test_variable_scalar_no_initial_value(
+  // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+  func.func @test_variable_scalar_no_initial_value(%arg0: tensor<f32>) -> () {
+    // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
+    %0 = tosa.variable_read @stored_var : tensor<f32>
+    // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+    tosa.variable_write @stored_var, %1 : tensor<f32>
+    return
+  }
 }
 
 // -----
-// CHECK-LABEL:   @test_variable_tensor_no_initial_value(
-// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
-func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () {
-  // CHECK:           tosa.variable @stored_var : tensor<2x4x8xi32>
+
+module {
+  // CHECK: tosa.variable @stored_var : tensor<2x4x8xi32>
   tosa.variable @stored_var : tensor<2x4x8xi32>
-  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
-  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
-  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
-  %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
-  // CHECK:           tosa.v...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 12, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

Implement SymbolOpInterface on tosa.variable so that it's declaration is automatically inserted into its parents SymbolTable.

Verifiers for tosa.variable_read/write can now look up the symbol and guarantee it exists, and duplicate names are caught at creation time. Previously this was completed by walking the graph which could be inefficient.

Unfortunately, the Symbol trait expects to find a symbol name via a hard-coded attribute name "sym_name". Therefore, "name" is renamed to"sym_name" and a getName() wrapper is provided for backwards compatibility.

This change also restricts tosa.variable declarations to ops that carry a SymbolTable (e.g. modules), rather than allowing them to be placed inside a func.func.

Note: EXT-VARIABLE is an experimental extension in the TOSA specification, so is not subject to backwards compatibility guarantees.


Patch is 26.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153223.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td (+12-4)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+15-52)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+26-29)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+14-8)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+18-12)
  • (modified) mlir/test/Dialect/Tosa/variables.mlir (+77-55)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+33-23)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..41094748fb528 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -201,9 +201,9 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
 // and optional initial value. The builder will extract var_shape and element type
 // attributes from variable type.
 def Tosa_VariableOpBuilder : OpBuilder<
-  (ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
+  (ins "StringRef":$sym_name, "Type":$variable_type, "Attribute":$initial_value),
   [{
-    buildVariableOp($_builder, $_state, name, variable_type, initial_value);
+    buildVariableOp($_builder, $_state, sym_name, variable_type, initial_value);
   }]>;
 
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index c8f2907f8dd1b..270daf8d5668e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -18,6 +18,7 @@
 include "mlir/IR/OpBase.td"
 
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
@@ -82,7 +83,7 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
 //===----------------------------------------------------------------------===//
 // Operator: variable
 //===----------------------------------------------------------------------===//
-def Tosa_VariableOp : Tosa_Op<"variable", []> {
+def Tosa_VariableOp : Tosa_Op<"variable", [Symbol]> {
   let summary = "Defines a variable";
 
   let description = [{
@@ -91,7 +92,10 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
   }];
 
   let arguments = (ins
-    SymbolNameAttr:$name,
+    // Note: "sym_name" is used as opposed to "name" in the specification,
+    // since a Symbol must be named "sym_name" for it to be recognised by
+    // the containing SymbolTable.
+    SymbolNameAttr:$sym_name,
     IndexElementsAttr:$var_shape,
     TypeAttr:$type,
     OptionalAttr<AnyAttr>:$initial_value
@@ -105,14 +109,18 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
   let hasCustomAssemblyFormat = 1;
 
   let assemblyFormat = [{
-    $name
+    $sym_name
     attr-dict
     custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
   }];
 
   let builders = [Tosa_VariableOpBuilder];
 
-  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    ::llvm::StringRef getName() {
+      return getSymName();
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3cafb199d2db3..89f1a03e3abe4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -667,56 +667,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
   return shapeAdaptor.getNumElements() == 1 ? success() : failure();
 }
 
-// Returns the first declaration point prior to this operation or failure if
-// not found.
-static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
-                                                    StringRef symName) {
-  ModuleOp module = op->getParentOfType<ModuleOp>();
-  tosa::VariableOp varOp = nullptr;
-
-  // TODO: Adopt SymbolTable trait to Varible ops.
-  // Currently, the variable's definition point is searched via walk(),
-  // starting from the top-level ModuleOp and stopping at the point of use. Once
-  // TOSA control flow and variable extensions reach the complete state, may
-  // leverage MLIR's Symbol Table functionality to look up symbol and enhance
-  // the search to a TOSA specific graph traversal over the IR structure.
-  module.walk([&](Operation *tempOp) {
-    // Reach this op itself.
-    if (tempOp == op) {
-      return WalkResult::interrupt();
-    }
-
-    if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
-      if (symName == tosaOp.getName()) {
-        varOp = tosaOp;
-        return WalkResult::interrupt();
-      }
-    }
-
-    return WalkResult::advance();
-  });
-
-  if (varOp)
-    return varOp;
-
-  return failure();
-}
-
 template <typename T>
 static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
-  StringRef symName = op.getName();
-  FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
-  if (failed(varOp))
+  Operation *symTableOp =
+      op->template getParentWithTrait<OpTrait::SymbolTable>();
+  if (!symTableOp)
+    // If the operation is not the scope of a symbol table, we cannot
+    // verify it against it's declaration.
+    return success();
+
+  SymbolTable symTable(symTableOp);
+  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
+
+  // Verify prior declaration
+  if (!varOp)
     return op->emitOpError("'")
-           << symName << "' has not been declared by 'tosa.variable'";
+           << op.getName() << "' has not been declared by 'tosa.variable'";
 
   // Verify type and shape
-  auto variableType = getVariableType(varOp.value());
+  auto variableType = getVariableType(varOp);
   if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
                                  "the input tensor")
           .failed())
     return failure();
-
   return success();
 }
 
@@ -1180,7 +1153,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result,
   ArrayRef<int64_t> shape = shapedType.getShape();
   auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
 
-  result.addAttribute("name", nameAttr);
+  result.addAttribute("sym_name", nameAttr);
   result.addAttribute("var_shape", varShapeAttr);
   result.addAttribute("type", elementTypeAttr);
   result.addAttribute("initial_value", initialValue);
@@ -3908,16 +3881,6 @@ LogicalResult tosa::SelectOp::verify() {
   return success();
 }
 
-LogicalResult tosa::VariableOp::verify() {
-  StringRef symName = getName();
-  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
-  if (succeeded(varOp))
-    return emitOpError("illegal to have multiple declaration of '")
-           << symName << "'";
-
-  return success();
-}
-
 LogicalResult tosa::VariableReadOp::verify() {
   if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
           .failed())
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 3bccb32c5b9f4..c024f578f2946 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -564,64 +564,61 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten
 
 // -----
 
-func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   tosa.variable @stored_var : tensor<*xi8>
   // expected-error@+1 {{custom op 'tosa.variable' expected ranked type}}
-  return
 }
 
 // -----
 
-func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   // expected-error@+1 {{elements literal type must have static shape}}
   tosa.variable @stored_var = dense<0> : tensor<*xi8>
   // expected-error@+1 {{custom op 'tosa.variable' expected attribute}}
-  return
-}
-
-// -----
-
-func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () {
-  tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}}
-  tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8>
-  return
 }
 
 // -----
 
-func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
-  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
-  return
+  func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+    // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}}
+    %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16>
+    return
+  }
 }
 
 // -----
 
-func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
-  %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
-  return
+  func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () {
+    // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}}
+    %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32>
+    return
+  }
 }
 
 // -----
 
-func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+module {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
-  tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
-  return
+  func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
+    // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}}
+    tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16>
+    return
+  }
 }
 
 // -----
 
-func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
+module {
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
-  tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
-  return
+  func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () {
+    // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}}
+    tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8>
+    return
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 3154f541e0519..a7bb0138103ab 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -310,21 +310,27 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> {
 }
 
 // -----
-func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   // expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
-  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
-  return
+
+  func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () {
+    // expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}}
+    %0 = tosa.variable_read @stored_var : tensor<2x4x8xi8>
+    return
+  }
 }
 
 // -----
-func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
+module {
   // expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}}
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8>
-  // expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
-  tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
-  return
+
+  func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () {
+    // expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}}
+    tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8>
+    return
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 0184d2b05f0ee..633705b41f8e7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1088,14 +1088,17 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %
 
 // -----
 
-func.func @test_variable_read_write_tensor_size_invalid() -> () {
+module {
   // expected-error@+1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
   tosa.variable @stored_var : tensor<536870912xf32>
-  // expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.variable_read @stored_var : tensor<536870912xf32>
-  // expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
-  return
+
+  func.func @test_variable_read_write_tensor_size_invalid() -> () {
+    // expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+    %0 = tosa.variable_read @stored_var : tensor<536870912xf32>
+    // expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+    tosa.variable_write @stored_var, %0 : tensor<536870912xf32>
+    return
+  }
 }
 
 // -----
@@ -1156,14 +1159,17 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1:
 
 // -----
 
-func.func @test_variable_read_write_rank_invalid() -> () {
+module {
   // expected-error@+1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}}
   tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
-  // expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
-  %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
-  // expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
-  tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
-  return
+
+  func.func @test_variable_read_write_rank_invalid() -> () {
+    // expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}}
+    %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32>
+    // expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}}
+    tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32>
+    return
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir
index 9953eb375d3ac..0c104e8e8d7ea 100644
--- a/mlir/test/Dialect/Tosa/variables.mlir
+++ b/mlir/test/Dialect/Tosa/variables.mlir
@@ -3,76 +3,98 @@
 
 
 // -----
-// CHECK-LABEL:   @test_variable_scalar(
-// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<f32>) {
-func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
-  // CHECK:           tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
+
+module {
+  // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor<f32>
   tosa.variable @stored_var = dense<3.14> : tensor<f32>
-  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
-  %0 = tosa.variable_read @stored_var : tensor<f32>
-  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
-  tosa.variable_write @stored_var, %1 : tensor<f32>
-  return
+
+  // CHECK-LABEL: @test_variable_scalar(
+  // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+  func.func @test_variable_scalar(%arg0: tensor<f32>) -> () {
+    // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
+    %0 = tosa.variable_read @stored_var : tensor<f32>
+    // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+    tosa.variable_write @stored_var, %1 : tensor<f32>
+    return
+  }
 }
 
+
 // -----
-// CHECK-LABEL:   @test_variable_tensor(
-// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
-func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
-  // CHECK:           tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
+
+module {
+  // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
   tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
-  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
-  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
-  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
-  %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
-  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
-  tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
-  return
+
+  // CHECK-LABEL: @test_variable_tensor(
+  // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
+  func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () {
+    // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+    %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
+    // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+    %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
+    // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32>
+    tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32>
+    return
+  }
 }
 
 // -----
-// CHECK-LABEL:   @test_variable_scalar_no_initial_value(
-// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<f32>) {
-func.func @test_variable_scalar_no_initial_value(%arg0: tensor<f32>) -> () {
-  // CHECK:           tosa.variable @stored_var : tensor<f32>
+
+module {
+  // CHECK: tosa.variable @stored_var : tensor<f32>
   tosa.variable @stored_var : tensor<f32>
-  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
-  %0 = tosa.variable_read @stored_var : tensor<f32>
-  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  // CHECK:           tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
-  tosa.variable_write @stored_var, %1 : tensor<f32>
-  return
+
+  // CHECK-LABEL: @test_variable_scalar_no_initial_value(
+  // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<f32>) {
+  func.func @test_variable_scalar_no_initial_value(%arg0: tensor<f32>) -> () {
+    // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<f32>
+    %0 = tosa.variable_read @stored_var : tensor<f32>
+    // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %1 = "tosa.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<f32>
+    tosa.variable_write @stored_var, %1 : tensor<f32>
+    return
+  }
 }
 
 // -----
-// CHECK-LABEL:   @test_variable_tensor_no_initial_value(
-// CHECK-SAME:                        %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) {
-func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () {
-  // CHECK:           tosa.variable @stored_var : tensor<2x4x8xi32>
+
+module {
+  // CHECK: tosa.variable @stored_var : tensor<2x4x8xi32>
   tosa.variable @stored_var : tensor<2x4x8xi32>
-  // CHECK:           %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32>
-  %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32>
-  // CHECK:           %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
-  %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32>
-  // CHECK:           tosa.v...
[truncated]

@lhutton1
Copy link
Contributor Author

@lhutton1
Copy link
Contributor Author

lhutton1 commented Sep 1, 2025

ping for reviews on this

attr-dict
custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
}];

let builders = [Tosa_VariableOpBuilder];

let hasVerifier = 1;
let extraClassDeclaration = [{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be getSymbolName() to avoid ambiguity ?

Copy link
Contributor Author

@lhutton1 lhutton1 Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was to avoid breaking dependent code that might have previously relied on getName() (before the renaming). In theory, we could avoid providing this, breaking the API, and make it clear in the PR that this is a breaking change

Copy link
Contributor

@udaya-ranga udaya-ranga left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Tai78641
Copy link
Contributor

LGTM

Implement SymbolOpInterface on tosa.variable so that it's declaration is
automatically inserted into its parents SymbolTable.

Verifiers for tosa.variable_read/write can now look up the symbol and
guarantee it exists, and duplicate names are caught at creation time.
Previously this was completed by walking the graph which could be
inefficient.

Unfortunately, the Symbol trait expects to find a symbol name
via a hard-coded attribute name "sym_name". Therefore, "name" is renamed
to"sym_name" and a getName() wrapper is provided for backwards
compatibility.

This change also restricts tosa.variable declarations to ops that carry
a SymbolTable (e.g. modules), rather than allowing them to be placed
inside a func.func.

Note: EXT-VARIABLE is an experimental extension in the TOSA
specification, so is not subject to backwards compatibility
guarantees.

Change-Id: I00a3f8f3b3b4f68cb3c120fe2c928d7b74b214cb
@lhutton1 lhutton1 force-pushed the variable-symbol-trait branch from 1509d7c to da051fe Compare October 4, 2025 10:09
@lhutton1 lhutton1 merged commit fee71a3 into llvm:main Oct 6, 2025
9 checks passed
@lhutton1 lhutton1 deleted the variable-symbol-trait branch October 6, 2025 12:14
@llvm-ci
Copy link
Collaborator

llvm-ci commented Oct 6, 2025

LLVM Buildbot has detected a new failure on builder clang-aarch64-sve2-vla-2stage running on linaro-g4-02 while building mlir at step 12 "ninja check 2".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/199/builds/6334

Here is the relevant piece of the build log for the reference
Step 12 (ninja check 2) failure: stage 2 checked (failure)
******************** TEST 'LLVM :: tools/llvm-exegesis/AArch64/no-aliasing-ld-str.s' FAILED ********************
Exit Code: -7

Command Output (stdout):
--
# RUN: at line 3
/home/tcwg-buildbot/worker/clang-aarch64-sve2-vla-2stage/stage2/bin/llvm-exegesis -mtriple=aarch64 -mcpu=neoverse-v2 -mode=latency --dump-object-to-disk=%d --opcode-name=FMOVWSr --benchmark-phase=assemble-measured-code 2>&1
# executed command: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla-2stage/stage2/bin/llvm-exegesis -mtriple=aarch64 -mcpu=neoverse-v2 -mode=latency --dump-object-to-disk=%d --opcode-name=FMOVWSr --benchmark-phase=assemble-measured-code
# .---command stdout------------
# | Check generated assembly with: /usr/bin/objdump -d %d
# | ---
# | mode:            latency
# | key:
# |   instructions:
# |     - 'FMOVWSr S25 W6'
# |     - 'FCVTZSUXSr X8 S25'
# |   config:          ''
# |   register_initial_values:
# |     - 'W6=0x0'
# |     - 'FPCR=0x0'
# | cpu_name:        neoverse-v2
# | llvm_triple:     aarch64
# | min_instructions: 10000
# | measurements:    []
# | error:           actual measurements skipped.
# | info:            Repeating two instructions
# | assembled_snippet: 06008052080080D208441BD5D900271E2803389ED900271E2803389ED900271E2803389ED900271E2803389EC0035FD6
# | ...
# `-----------------------------
# RUN: at line 4
/home/tcwg-buildbot/worker/clang-aarch64-sve2-vla-2stage/stage2/bin/llvm-objdump -d %d > /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla-2stage/stage2/test/tools/llvm-exegesis/AArch64/Output/no-aliasing-ld-str.s.tmp.s
# executed command: /home/tcwg-buildbot/worker/clang-aarch64-sve2-vla-2stage/stage2/bin/llvm-objdump -d %d
# .---redirected output from '/home/tcwg-buildbot/worker/clang-aarch64-sve2-vla-2stage/stage2/test/tools/llvm-exegesis/AArch64/Output/no-aliasing-ld-str.s.tmp.s'
# | 
# | %d:	file format elf64-littleaarch64
# | 
# | Disassembly of section .text:
# | 
# | 0000000000000000 <foo>:
# |        0: 52800006     	mov	w6, #0x0                // =0
# |        4: d2800008     	mov	x8, #0x0                // =0
# |        8: d51b4408     	msr	FPCR, x8
# |        c: 1e2700d9     	fmov	s25, w6
# |       10: 9e380328     	fcvtzs	x8, s25
# |       14: 1e2700d9     	fmov	s25, w6
# |       18: 9e380328     	fcvtzs	x8, s25
# |       1c: 1e2700d9     	fmov	s25, w6
# |       20: 9e380328     	fcvtzs	x8, s25
# |       24: 1e2700d9     	fmov	s25, w6
# |       28: 9e380328     	fcvtzs	x8, s25
...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants